1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package com.google.common.math;
18
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkNotNull;
21 import static com.google.common.math.MathPreconditions.checkNonNegative;
22 import static com.google.common.math.MathPreconditions.checkPositive;
23 import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
24 import static java.math.RoundingMode.CEILING;
25 import static java.math.RoundingMode.FLOOR;
26 import static java.math.RoundingMode.HALF_EVEN;
27
28 import com.google.common.annotations.GwtCompatible;
29 import com.google.common.annotations.GwtIncompatible;
30 import com.google.common.annotations.VisibleForTesting;
31
32 import java.math.BigDecimal;
33 import java.math.BigInteger;
34 import java.math.RoundingMode;
35 import java.util.ArrayList;
36 import java.util.List;
37
38
39
40
41
42
43
44
45
46
47
48
49
50 @GwtCompatible(emulated = true)
51 public final class BigIntegerMath {
52
53
54
55 public static boolean isPowerOfTwo(BigInteger x) {
56 checkNotNull(x);
57 return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
58 }
59
60
61
62
63
64
65
66
67 @SuppressWarnings("fallthrough")
68
69 public static int log2(BigInteger x, RoundingMode mode) {
70 checkPositive("x", checkNotNull(x));
71 int logFloor = x.bitLength() - 1;
72 switch (mode) {
73 case UNNECESSARY:
74 checkRoundingUnnecessary(isPowerOfTwo(x));
75 case DOWN:
76 case FLOOR:
77 return logFloor;
78
79 case UP:
80 case CEILING:
81 return isPowerOfTwo(x) ? logFloor : logFloor + 1;
82
83 case HALF_DOWN:
84 case HALF_UP:
85 case HALF_EVEN:
86 if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
87 BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
88 SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
89 if (x.compareTo(halfPower) <= 0) {
90 return logFloor;
91 } else {
92 return logFloor + 1;
93 }
94 }
95
96
97
98
99
100
101 BigInteger x2 = x.pow(2);
102 int logX2Floor = x2.bitLength() - 1;
103 return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
104
105 default:
106 throw new AssertionError();
107 }
108 }
109
110
111
112
113
114
115 @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
116
117 @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
118 new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
119
120
121
122
123
124
125
126
127 @GwtIncompatible("TODO")
128 @SuppressWarnings("fallthrough")
129 public static int log10(BigInteger x, RoundingMode mode) {
130 checkPositive("x", x);
131 if (fitsInLong(x)) {
132 return LongMath.log10(x.longValue(), mode);
133 }
134
135 int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
136 BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
137 int approxCmp = approxPow.compareTo(x);
138
139
140
141
142
143
144 if (approxCmp > 0) {
145
146
147
148
149
150 do {
151 approxLog10--;
152 approxPow = approxPow.divide(BigInteger.TEN);
153 approxCmp = approxPow.compareTo(x);
154 } while (approxCmp > 0);
155 } else {
156 BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
157 int nextCmp = nextPow.compareTo(x);
158 while (nextCmp <= 0) {
159 approxLog10++;
160 approxPow = nextPow;
161 approxCmp = nextCmp;
162 nextPow = BigInteger.TEN.multiply(approxPow);
163 nextCmp = nextPow.compareTo(x);
164 }
165 }
166
167 int floorLog = approxLog10;
168 BigInteger floorPow = approxPow;
169 int floorCmp = approxCmp;
170
171 switch (mode) {
172 case UNNECESSARY:
173 checkRoundingUnnecessary(floorCmp == 0);
174
175 case FLOOR:
176 case DOWN:
177 return floorLog;
178
179 case CEILING:
180 case UP:
181 return floorPow.equals(x) ? floorLog : floorLog + 1;
182
183 case HALF_DOWN:
184 case HALF_UP:
185 case HALF_EVEN:
186
187 BigInteger x2 = x.pow(2);
188 BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
189 return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
190 default:
191 throw new AssertionError();
192 }
193 }
194
195 private static final double LN_10 = Math.log(10);
196 private static final double LN_2 = Math.log(2);
197
198
199
200
201
202
203
204
205 @GwtIncompatible("TODO")
206 @SuppressWarnings("fallthrough")
207 public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
208 checkNonNegative("x", x);
209 if (fitsInLong(x)) {
210 return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
211 }
212 BigInteger sqrtFloor = sqrtFloor(x);
213 switch (mode) {
214 case UNNECESSARY:
215 checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x));
216 case FLOOR:
217 case DOWN:
218 return sqrtFloor;
219 case CEILING:
220 case UP:
221 int sqrtFloorInt = sqrtFloor.intValue();
222 boolean sqrtFloorIsExact =
223 (sqrtFloorInt * sqrtFloorInt == x.intValue())
224 && sqrtFloor.pow(2).equals(x);
225 return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
226 case HALF_DOWN:
227 case HALF_UP:
228 case HALF_EVEN:
229 BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
230
231
232
233
234
235 return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
236 default:
237 throw new AssertionError();
238 }
239 }
240
241 @GwtIncompatible("TODO")
242 private static BigInteger sqrtFloor(BigInteger x) {
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262 BigInteger sqrt0;
263 int log2 = log2(x, FLOOR);
264 if (log2 < Double.MAX_EXPONENT) {
265 sqrt0 = sqrtApproxWithDoubles(x);
266 } else {
267 int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1;
268
269
270
271
272 sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
273 }
274 BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
275 if (sqrt0.equals(sqrt1)) {
276 return sqrt0;
277 }
278 do {
279 sqrt0 = sqrt1;
280 sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
281 } while (sqrt1.compareTo(sqrt0) < 0);
282 return sqrt0;
283 }
284
285 @GwtIncompatible("TODO")
286 private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
287 return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
288 }
289
290
291
292
293
294
295
296
297 @GwtIncompatible("TODO")
298 public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
299 BigDecimal pDec = new BigDecimal(p);
300 BigDecimal qDec = new BigDecimal(q);
301 return pDec.divide(qDec, 0, mode).toBigIntegerExact();
302 }
303
304
305
306
307
308
309
310
311
312
313
314
315
316 public static BigInteger factorial(int n) {
317 checkNonNegative("n", n);
318
319
320 if (n < LongMath.factorials.length) {
321 return BigInteger.valueOf(LongMath.factorials[n]);
322 }
323
324
325 int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
326 ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
327
328
329 int startingNumber = LongMath.factorials.length;
330 long product = LongMath.factorials[startingNumber - 1];
331
332 int shift = Long.numberOfTrailingZeros(product);
333 product >>= shift;
334
335
336 int productBits = LongMath.log2(product, FLOOR) + 1;
337 int bits = LongMath.log2(startingNumber, FLOOR) + 1;
338
339 int nextPowerOfTwo = 1 << (bits - 1);
340
341
342 for (long num = startingNumber; num <= n; num++) {
343
344 if ((num & nextPowerOfTwo) != 0) {
345 nextPowerOfTwo <<= 1;
346 bits++;
347 }
348
349 int tz = Long.numberOfTrailingZeros(num);
350 long normalizedNum = num >> tz;
351 shift += tz;
352
353 int normalizedBits = bits - tz;
354
355 if (normalizedBits + productBits >= Long.SIZE) {
356 bignums.add(BigInteger.valueOf(product));
357 product = 1;
358 productBits = 0;
359 }
360 product *= normalizedNum;
361 productBits = LongMath.log2(product, FLOOR) + 1;
362 }
363
364 if (product > 1) {
365 bignums.add(BigInteger.valueOf(product));
366 }
367
368 return listProduct(bignums).shiftLeft(shift);
369 }
370
371 static BigInteger listProduct(List<BigInteger> nums) {
372 return listProduct(nums, 0, nums.size());
373 }
374
375 static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
376 switch (end - start) {
377 case 0:
378 return BigInteger.ONE;
379 case 1:
380 return nums.get(start);
381 case 2:
382 return nums.get(start).multiply(nums.get(start + 1));
383 case 3:
384 return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
385 default:
386
387 int m = (end + start) >>> 1;
388 return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
389 }
390 }
391
392
393
394
395
396
397
398
399
400 public static BigInteger binomial(int n, int k) {
401 checkNonNegative("n", n);
402 checkNonNegative("k", k);
403 checkArgument(k <= n, "k (%s) > n (%s)", k, n);
404 if (k > (n >> 1)) {
405 k = n - k;
406 }
407 if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
408 return BigInteger.valueOf(LongMath.binomial(n, k));
409 }
410
411 BigInteger accum = BigInteger.ONE;
412
413 long numeratorAccum = n;
414 long denominatorAccum = 1;
415
416 int bits = LongMath.log2(n, RoundingMode.CEILING);
417
418 int numeratorBits = bits;
419
420 for (int i = 1; i < k; i++) {
421 int p = n - i;
422 int q = i + 1;
423
424
425
426 if (numeratorBits + bits >= Long.SIZE - 1) {
427
428
429 accum = accum
430 .multiply(BigInteger.valueOf(numeratorAccum))
431 .divide(BigInteger.valueOf(denominatorAccum));
432 numeratorAccum = p;
433 denominatorAccum = q;
434 numeratorBits = bits;
435 } else {
436
437 numeratorAccum *= p;
438 denominatorAccum *= q;
439 numeratorBits += bits;
440 }
441 }
442 return accum
443 .multiply(BigInteger.valueOf(numeratorAccum))
444 .divide(BigInteger.valueOf(denominatorAccum));
445 }
446
447
448 @GwtIncompatible("TODO")
449 static boolean fitsInLong(BigInteger x) {
450 return x.bitLength() <= Long.SIZE - 1;
451 }
452
453 private BigIntegerMath() {}
454 }